Introduction

This notebook walks you through building an LSTM (Long Short-Term Memory) model with TensorFlow. The model can automatically insert punctuation into paragraphs that lack it. For example, given:

i think it is a report which will for the most part be supported by my group

It produces:

i think it is a report which will , for the most part , be supported by my group .

One imagined use of the model is typing assistance --- you type a stream of words and let the model insert punctuation for you. It could also be used to punctuate speech recognition output.

The model does not rely on capitalization. All training and prediction data are converted to lowercase during data preparation.

Send any feedback to datalab-feedback@google.com.

Prepare Data

The training data are the europarl_raw and comtrans corpora from NLTK. I think both are extracted from the proceedings of the European Parliament. I chose these two datasets because, first, they have clean punctuation, and second, they are large enough to train a decent model.


In [13]:
# Download and unzip data.

!mkdir -p /content/datalab/punctuation/tmp
!mkdir -p /content/datalab/punctuation/data
!mkdir -p /content/datalab/punctuation/datapreped
!wget -q -P /content/datalab/punctuation/tmp/ https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/europarl_raw.zip
!wget -q -P /content/datalab/punctuation/tmp/ https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/comtrans.zip
!unzip -q -o /content/datalab/punctuation/tmp/europarl_raw.zip -d /content/datalab/punctuation/tmp
!unzip -q -o /content/datalab/punctuation/tmp/comtrans.zip -d /content/datalab/punctuation/tmp
!cp /content/datalab/punctuation/tmp/europarl_raw/english/* /content/datalab/punctuation/data

In [14]:
# We only need the English text from the `comtrans` data. The alignment file cycles
# through groups of three lines, with the English sentence first, so keep only that line.
with open('/content/datalab/punctuation/tmp/comtrans/alignment-en-fr.txt', 'r') as f_in, \
    open('/content/datalab/punctuation/data/comtrans.txt', 'w') as f_out:
  num_lines = 0
  for l in f_in.readlines():
    if num_lines == 0:
      f_out.write(l)
    num_lines = (0 if num_lines == 2 else num_lines + 1)

In [15]:
"""Prepare data by cleaning up text."""

import glob
import os
from random import randint
import re
import string


def prep_data(corpora_path, out_dir):
  """Clean up raw data and split them into train, validation, and test source."""
  
  printable = set(string.printable)
  all_corpora_files = glob.glob(corpora_path)
  lines = []

  for corpora_file in all_corpora_files:
    with open(corpora_file, 'r') as f:
      lines += f.readlines()
  
  dest_train = os.path.join(out_dir, 'train.txt') 
  dest_valid = os.path.join(out_dir, 'valid.txt')  
  dest_test = os.path.join(out_dir, 'test.txt') 

  valid_lines = 0
  test_lines = 0
  train_lines = 0

  with open(dest_train, 'w') as f_train, open(dest_valid, 'w') as f_valid, open(dest_test, 'w') as f_test:
    for l in lines:
      s = l.strip()
      # Remove "bad" sentences.
      if s.endswith(')') and s.startswith('('):
        continue
      if not s.endswith('.') and not s.endswith('!') and not s.endswith('?'):
        continue
      if s.find('...') != -1:
        continue

      # Remove quotes, apostrophes, leading dashes.        
      s = re.sub('"', '', s)
      s = re.sub(' \' s ', 's ', s)   
      s = re.sub('\'', '', s)
      s = re.sub('^- ', '', s)
      
      # Clean up double punctuation such as "? ." and "! .".
      s = re.sub('\? \.', '?', s)
      s = re.sub('\! \.', '!', s)
      
      # Extract human names to reduce vocab size. There are many names like 'Mrs Plooij-van Gorsel'
      # 'Mr Cox'.
      s = re.sub('Mr [\w]+ [A-Z][\w]+ ', '[humanname] ', s)
      s = re.sub('Mrs [\w]+ [A-Z][\w]+ ', '[humanname] ', s)
      s = re.sub('Mr [\w]+ ', '[humanname] ', s)
      s = re.sub('Mrs [\w]+ ', '[humanname] ', s)
      
      # Remove brackets and contents inside.
      s = re.sub('\(.*\) ', '', s)
      s = re.sub('\(', '', s)
      s = re.sub('\)', '', s)
      
      # Extract numbers to reduce the vocab size.
      s = re.sub('[0-9\.]+ ', '[number] ', s)  
      
      # Replace i.e., p.m., a.m. to reduce confusion on period.
      s = re.sub(' i\.e\.', ' for example', s)          
      s = re.sub(' p\.m\.', ' pm', s)   
      s = re.sub(' a\.m\.', ' am', s) 
      
      # Remove unprintable characters.
      s = filter(lambda x: x in printable, s)
      
      s = s.lower()
      
      # For every 3 sentences we cut a new line to simulate a paragraph.
      # Split into train/validation/test sets at roughly 78:20:2.
      r = randint(0,50)
      if r < 10:
        valid_lines += 1
        sep = '\n' if (valid_lines % 3) == 0 else ' '
        f_valid.write(s + sep)
      elif r == 11:
        test_lines += 1
        sep = '\n' if (test_lines % 3) == 0 else ' '
        f_test.write(s + sep)
      else:
        train_lines += 1
        sep = '\n' if (train_lines % 3) == 0 else ' '
        f_train.write(s + sep)


prep_data('/content/datalab/punctuation/data/*', '/content/datalab/punctuation/datapreped')
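
To get a feel for what the cleanup does, here is a small sketch that applies a few of the prep_data substitutions to a made-up sentence (the sentence and its output are illustrative only, not taken from the corpora):


In [ ]:
import re

# Illustrative only: run a few of the same substitutions used in prep_data
# on a made-up sentence to see the [humanname] / [number] extraction.
sample = 'Mr Cox said on 12.5 million euros ( see annex ) that all is well .'
s = re.sub('Mr [\w]+ ', '[humanname] ', sample)
s = re.sub('\(.*\) ', '', s)
s = re.sub('[0-9\.]+ ', '[number] ', s)
print(s.lower())
# [humanname] said on [number] million euros that all is well .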

Training

Some of the code is ported from the TensorFlow PTB language model tutorial.


In [18]:
# We handle only a limited set of punctuation marks because the training data is limited.
PUNCTUATIONS = (u'.', u',', u'?', u'!', u':')
# `n` means no punctuation.
TARGETS = list(PUNCTUATIONS) + ['n']
# Set the vocab size to drop low-frequency words. With a vocabulary of roughly 10000, words that occur fewer than three times are excluded.
VOCAB_SIZE = 10000
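
The "fewer than three counts" claim is easy to sanity-check once train.txt exists: count word frequencies and look at the rarest word that still makes it into a 10000-word vocabulary (a rough check that ignores the <eos> handling in read_words):


In [ ]:
import collections

# Rough sanity check of VOCAB_SIZE: frequency of the least frequent word
# that would still be kept in a vocabulary of VOCAB_SIZE - 1 words.
with open('/content/datalab/punctuation/datapreped/train.txt') as f:
  counter = collections.Counter(f.read().split())
most_common = counter.most_common(VOCAB_SIZE - 1)
print('distinct words: %d' % len(counter))
print('count of rarest in-vocab word: %d' % most_common[-1][1])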

In [19]:
"""Helper functions for reading input data."""

import collections
import os
import tensorflow as tf


def read_words(filename):
  """Read words from file.
  Args:
    filename: path to the file to read words from.
  Returns:
    Words split by white space.
  """
  with tf.gfile.GFile(filename, "r") as f:
    x = f.read().decode("utf-8").replace("\n", " <eos> ").split()
  if x[-1] != '<eos>':
    x.append('<eos>')

  indices = [i for i, w in enumerate(x) if w in PUNCTUATIONS]
  # The next word after a punctuation is an important signal. We switch the punctuation
  # with next word so it can be used as part of the context.
  for i in indices:
    x[i], x[i+1] = x[i+1], x[i]
  return x


def build_vocab(filename):
  """Build vocabulary from training data file.
  Args:
    filename: path to the file to read words from.
  Returns:
    A dict with key being words and value being indices.
  """
  x = read_words(filename)
  counter = collections.Counter(x)
  count_pairs = sorted(counter.items(), key=lambda a: (-a[1], a[0]))
  count_pairs = count_pairs[:VOCAB_SIZE-1]
  words, _ = list(zip(*count_pairs))
  word_to_id = dict(zip(words, range(len(words))))
  word_to_id['<unk>'] = VOCAB_SIZE - 1
  return word_to_id


def file_to_word_and_punc_ids(filename, word_to_id):
  """Produce indices from words in file. x are indices for words, and y are indices for punctuations.
  Args:
    filename: path to the file to read words from.
    word_to_id: the vocab to indices dict.
  Returns:
    A pair. First element is the words indices. Second element is the target punctuation indices.
  """
  x_words = read_words(filename)
  x_id = [word_to_id[w] if w in word_to_id else word_to_id['<unk>'] for w in x_words]
  target_to_id = {p:i for i, p in enumerate(TARGETS)}
  y_words = x_words[1:] + ['padding']
  y_puncts = ['n' if elem not in PUNCTUATIONS else elem for elem in y_words]
  y_id = [target_to_id[p] for p in y_puncts]
  return x_id, y_id


def content_to_word_ids(content, word_to_id):
  """Produce indices from words from a given string.
  Args:
    filename: path to the file to read words from.
    word_to_id: the vocab to indices dict.    
  Returns:
    Words indices.
  """
  x = content.decode("utf-8").replace("\n", " <eos> ").split()
  indices = [i for i, w in enumerate(x) if w in PUNCTUATIONS]
  for i in indices:
    x[i], x[i+1] = x[i+1], x[i]

  x_id = [word_to_id[w] if w in word_to_id else word_to_id['<unk>'] for w in x]
  return x_id
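
The punctuation/word swap and the target alignment are easier to see on a tiny in-memory example (illustrative only, no file involved):


In [ ]:
# Illustrative only: mimic the swap in read_words and the target alignment
# in file_to_word_and_punc_ids on a tiny sentence.
words = 'i agree , for the most part .'.split() + ['<eos>']
indices = [i for i, w in enumerate(words) if w in PUNCTUATIONS]
for i in indices:
  words[i], words[i + 1] = words[i + 1], words[i]
print(words)
# ['i', 'agree', 'for', ',', 'the', 'most', 'part', '<eos>', '.']

# The target for each word is the punctuation (or 'n') at the next position,
# so 'for' gets target ',' and '<eos>' gets target '.'.
targets = ['n' if w not in PUNCTUATIONS else w for w in words[1:] + ['padding']]
print(list(zip(words, targets)))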

In [1]:
"""The training model. """

import tensorflow as tf
import json


class TrainingConfig(object):
  init_scale = 0.1
  learning_rate = 1.0
  max_grad_norm = 5
  num_layers = 2
  num_steps = 50
  hidden_size = 150
  max_epoch = 20
  max_max_epoch = 25
  keep_prob = 0.5
  lr_decay = 0.7
  batch_size = 100


class TrainingInput(object):
  """The input data producer."""

  def _make_input_producer(self, raw_data, batch_size, num_steps, name=None):
    with tf.name_scope(name, "InputProducer"):
      raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)

      data_len = tf.size(raw_data)
      batch_len = data_len // batch_size
      data = tf.reshape(raw_data[0 : batch_size * batch_len], [batch_size, batch_len])

      epoch_size = (batch_len - 1) // num_steps
      epoch_size = tf.identity(epoch_size, name="epoch_size")

      i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
      x = tf.strided_slice(data, [0, i * num_steps], [batch_size, (i + 1) * num_steps])
      x.set_shape([batch_size, num_steps])
      return x

  def __init__(self, config, data_x, data_y, name=None):
    self.epoch_size = ((len(data_x) // config.batch_size) - 1) // config.num_steps
    self.input_data = self._make_input_producer(data_x, config.batch_size, config.num_steps, name=name)
    self.targets = self._make_input_producer(data_y, config.batch_size, config.num_steps, name=name)


class PuctuationModel(object):
  """The Punctuation training/evaluation model."""

  def __init__(self, is_training, config, input_):
    self._input = input_
    batch_size = config.batch_size
    num_steps = config.num_steps
    size = config.hidden_size

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0, state_is_tuple=True)
    
    attn_cell = lstm_cell
    if is_training and config.keep_prob < 1:
      def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(lstm_cell(), output_keep_prob=config.keep_prob)

    cell = tf.contrib.rnn.MultiRNNCell([attn_cell() for _ in range(config.num_layers)], state_is_tuple=True)
    self._initial_state = cell.zero_state(batch_size, tf.float32)
    embedding = tf.get_variable("embedding", [VOCAB_SIZE, size], dtype=tf.float32)
    inputs = tf.nn.embedding_lookup(embedding, input_.input_data)
    
    if is_training and config.keep_prob < 1:
      inputs = tf.nn.dropout(inputs, config.keep_prob)

    inputs = tf.unstack(inputs, num=num_steps, axis=1)
    outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, initial_state=self._initial_state)

    output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, size])
    softmax_w = tf.get_variable("softmax_w", [size, len(TARGETS)], dtype=tf.float32)
    softmax_b = tf.get_variable("softmax_b", [len(TARGETS)], dtype=tf.float32)
    logits = tf.matmul(output, softmax_w) + softmax_b
    self._predictions = tf.argmax(logits, 1)    
    self._targets = tf.reshape(input_.targets, [-1])
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
        [logits],
        [tf.reshape(input_.targets, [-1])],
        [tf.ones([batch_size * num_steps], dtype=tf.float32)])
    self._cost = cost = tf.reduce_sum(loss) / batch_size
    self._final_state = state

    if not is_training:
      return

    self._lr = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), config.max_grad_norm)
    optimizer = tf.train.GradientDescentOptimizer(self._lr)
    self._train_op = optimizer.apply_gradients(
        zip(grads, tvars),
        global_step=tf.contrib.framework.get_or_create_global_step())

    self._new_lr = tf.placeholder(tf.float32, shape=[], name="new_learning_rate")
    self._lr_update = tf.assign(self._lr, self._new_lr)

  def assign_lr(self, session, lr_value):
    session.run(self._lr_update, feed_dict={self._new_lr: lr_value})

  @property
  def input(self):
    return self._input

  @property
  def initial_state(self):
    return self._initial_state
  
  @property
  def final_state(self):
    return self._final_state  

  @property
  def cost(self):
    return self._cost

  @property
  def predictions(self):
    return self._predictions

  @property
  def targets(self):
    return self._targets

  @property
  def lr(self):
    return self._lr

  @property
  def train_op(self):
    return self._train_op
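
Before training, it helps to see the batching arithmetic in TrainingInput: the flat sequence of word ids is laid out as a [batch_size, batch_len] grid and consumed in windows of num_steps, so the number of steps per epoch is (batch_len - 1) // num_steps. A quick calculation with made-up numbers (not the real dataset size):


In [ ]:
# Hypothetical numbers, purely to illustrate the batching arithmetic.
data_len = 1000000                   # pretend we have a million word ids
batch_size, num_steps = 100, 50      # as in TrainingConfig
batch_len = data_len // batch_size   # 10000 ids laid out per batch row
epoch_size = (batch_len - 1) // num_steps
print('batches of shape [%d, %d], %d steps per epoch' % (batch_size, num_steps, epoch_size))
# batches of shape [100, 50], 199 steps per epoch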

In [9]:
"""The trainer. """

import numpy as np

def run_epoch(session, model, num_steps, word_to_id, is_eval=False):
  """Runs the model on the given data for one epoch."""

  costs = 0.0
  iters = 0
  state = session.run(model.initial_state)

  fetches = {
      "cost": model.cost,
      "final_state": model.final_state,
      "predictions": model.predictions,
      "targets": model.targets,
  }
  if is_eval is False:
    fetches["train_op"] = model.train_op

  confusion_matrix = np.zeros(shape=(len(TARGETS),len(TARGETS)), dtype=np.int64)
  for step in range(model.input.epoch_size):
    feed_dict = {}
    # Set the state back to model after each run.
    for i, (c, h) in enumerate(model.initial_state):
      feed_dict[c] = state[i].c
      feed_dict[h] = state[i].h

    vals = session.run(fetches, feed_dict)
    cost = vals["cost"]
    state = vals["final_state"]
    targets = vals["targets"]
    predictions = vals['predictions']
    
    for t, p in zip(targets, predictions):
      confusion_matrix[t][p] += 1
    
    costs += cost
    iters += num_steps

  if is_eval is True:
    for i, t in enumerate(confusion_matrix):
      print('%s --- total: %d, correct: %d, accuracy: %.3f, ' % (TARGETS[i], sum(t), t[i], float(t[i]) / sum(t)))
      
  # Costs are accumulated cross-entropy loss.
  # Returns the perplexity (https://en.wikipedia.org/wiki/Perplexity), a common measure for language models.
  return np.exp(costs / iters), confusion_matrix


def train(train_data_path, validation_data_path, save_path):
  """Train the model and save a checkpoint at the end."""
  
  word_to_id = build_vocab(train_data_path)
  train_data_x, train_data_y = file_to_word_and_punc_ids(train_data_path, word_to_id)
  valid_data_x, valid_data_y = file_to_word_and_punc_ids(validation_data_path, word_to_id)
  config = TrainingConfig()

  with tf.Graph().as_default():
    initializer = tf.random_uniform_initializer(-config.init_scale, config.init_scale)
    with tf.name_scope("Train"):
      train_input = TrainingInput(config=config, data_x=train_data_x, data_y=train_data_y, name="TrainInput")
      with tf.variable_scope("Model", reuse=None, initializer=initializer):
        train_model = PuctuationModel(is_training=True, config=config, input_=train_input)
      tf.summary.scalar("Training_Loss", train_model.cost)
      tf.summary.scalar("Learning_Rate", train_model.lr)

    with tf.name_scope("Valid"):
      valid_input = TrainingInput(config=config, data_x=valid_data_x, data_y=valid_data_y, name="ValidInput")
      with tf.variable_scope("Model", reuse=True, initializer=initializer):
        valid_model = PuctuationModel(is_training=False, config=config, input_=valid_input)
      tf.summary.scalar("Validation_Loss", valid_model.cost)

    sv = tf.train.Supervisor(logdir=save_path)
    with sv.managed_session() as session:
      for i in range(config.max_max_epoch):
        lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
        train_model.assign_lr(session, config.learning_rate * lr_decay)

        print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(train_model.lr)))
        train_perplexity, _ = run_epoch(session, train_model, config.num_steps, word_to_id)
        print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
        valid_perplexity, _ = run_epoch(session, valid_model, config.num_steps, word_to_id, is_eval=True)
        print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))

      model_file_prefix = sv.saver.save(session, save_path, global_step=sv.global_step)

  word_to_id_file = os.path.join(os.path.dirname(save_path), 'word_to_id.json')
  with open(word_to_id_file, 'w') as outfile:
    json.dump(word_to_id, outfile)
  return model_file_prefix
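
The perplexity reported by run_epoch is just the exponential of the average per-position cross-entropy: cost is already divided by batch_size inside the model, and iters accumulates num_steps per batch, so costs / iters is the mean loss per word. A tiny worked example with a made-up loss value:


In [ ]:
import numpy as np

# Made-up number: an average cross-entropy of 0.14 nats per position
# corresponds to a perplexity of about 1.15.
avg_cross_entropy = 0.14
print(np.exp(avg_cross_entropy))   # ~1.15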

In [21]:
# Delete the model directory if it exists so training always starts from scratch.
!rm -r -f /content/datalab/punctuation/model

Start training. Training takes about 20 to 30 minutes on an n1-standard-1 GCP VM.


In [22]:
model_dir = '/content/datalab/punctuation/model'
saved_model_path = model_dir + '/punctuation'
model_file_prefix = train(
  '/content/datalab/punctuation/datapreped/train.txt',
  '/content/datalab/punctuation/datapreped/valid.txt',
  saved_model_path)


Epoch: 1 Learning rate: 1.000
Epoch: 1 Train Perplexity: 2.610
. --- total: 9675, correct: 3234, accuracy: 0.334, 
, --- total: 10403, correct: 0, accuracy: 0.000, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 214196, accuracy: 1.000, 
Epoch: 1 Valid Perplexity: 1.796
Epoch: 2 Learning rate: 1.000
Epoch: 2 Train Perplexity: 1.506
. --- total: 9675, correct: 7757, accuracy: 0.802, 
, --- total: 10403, correct: 2187, accuracy: 0.210, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 208732, accuracy: 0.974, 
Epoch: 2 Valid Perplexity: 1.270
Epoch: 3 Learning rate: 1.000
Epoch: 3 Train Perplexity: 1.305
. --- total: 9675, correct: 7076, accuracy: 0.731, 
, --- total: 10403, correct: 5156, accuracy: 0.496, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 207873, accuracy: 0.970, 
Epoch: 3 Valid Perplexity: 1.242
Epoch: 4 Learning rate: 1.000
Epoch: 4 Train Perplexity: 1.254
. --- total: 9675, correct: 7503, accuracy: 0.776, 
, --- total: 10403, correct: 5359, accuracy: 0.515, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 208477, accuracy: 0.973, 
Epoch: 4 Valid Perplexity: 1.216
Epoch: 5 Learning rate: 1.000
Epoch: 5 Train Perplexity: 1.225
. --- total: 9675, correct: 5989, accuracy: 0.619, 
, --- total: 10403, correct: 1194, accuracy: 0.115, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 214000, accuracy: 0.999, 
Epoch: 5 Valid Perplexity: 1.331
Epoch: 6 Learning rate: 1.000
Epoch: 6 Train Perplexity: 1.211
. --- total: 9675, correct: 7303, accuracy: 0.755, 
, --- total: 10403, correct: 6057, accuracy: 0.582, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 209294, accuracy: 0.977, 
Epoch: 6 Valid Perplexity: 1.181
Epoch: 7 Learning rate: 1.000
Epoch: 7 Train Perplexity: 1.193
. --- total: 9675, correct: 7618, accuracy: 0.787, 
, --- total: 10403, correct: 5638, accuracy: 0.542, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 210132, accuracy: 0.981, 
Epoch: 7 Valid Perplexity: 1.171
Epoch: 8 Learning rate: 1.000
Epoch: 8 Train Perplexity: 1.186
. --- total: 9675, correct: 6840, accuracy: 0.707, 
, --- total: 10403, correct: 2910, accuracy: 0.280, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 213682, accuracy: 0.998, 
Epoch: 8 Valid Perplexity: 1.195
Epoch: 9 Learning rate: 1.000
Epoch: 9 Train Perplexity: 1.172
. --- total: 9675, correct: 7074, accuracy: 0.731, 
, --- total: 10403, correct: 3558, accuracy: 0.342, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 213421, accuracy: 0.996, 
Epoch: 9 Valid Perplexity: 1.175
Epoch: 10 Learning rate: 1.000
Epoch: 10 Train Perplexity: 1.167
. --- total: 9679, correct: 7519, accuracy: 0.777, 
, --- total: 10415, correct: 6518, accuracy: 0.626, 
? --- total: 385, correct: 0, accuracy: 0.000, 
! --- total: 78, correct: 0, accuracy: 0.000, 
: --- total: 255, correct: 0, accuracy: 0.000, 
n --- total: 214188, correct: 209486, accuracy: 0.978, 
Epoch: 10 Valid Perplexity: 1.164
Epoch: 11 Learning rate: 1.000
Epoch: 11 Train Perplexity: 1.162
. --- total: 9675, correct: 6420, accuracy: 0.664, 
, --- total: 10403, correct: 3281, accuracy: 0.315, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 0, accuracy: 0.000, 
n --- total: 214200, correct: 213808, accuracy: 0.998, 
Epoch: 11 Valid Perplexity: 1.195
Epoch: 12 Learning rate: 1.000
Epoch: 12 Train Perplexity: 1.155
. --- total: 9675, correct: 7390, accuracy: 0.764, 
, --- total: 10403, correct: 6031, accuracy: 0.580, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 5, accuracy: 0.019, 
n --- total: 214200, correct: 211245, accuracy: 0.986, 
Epoch: 12 Valid Perplexity: 1.147
Epoch: 13 Learning rate: 1.000
Epoch: 13 Train Perplexity: 1.150
. --- total: 9675, correct: 7260, accuracy: 0.750, 
, --- total: 10403, correct: 5192, accuracy: 0.499, 
? --- total: 383, correct: 0, accuracy: 0.000, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 11, accuracy: 0.042, 
n --- total: 214200, correct: 212549, accuracy: 0.992, 
Epoch: 13 Valid Perplexity: 1.144
Epoch: 14 Learning rate: 1.000
Epoch: 14 Train Perplexity: 1.146
. --- total: 9675, correct: 7683, accuracy: 0.794, 
, --- total: 10403, correct: 6518, accuracy: 0.627, 
? --- total: 383, correct: 3, accuracy: 0.008, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 14, accuracy: 0.054, 
n --- total: 214200, correct: 210133, accuracy: 0.981, 
Epoch: 14 Valid Perplexity: 1.149
Epoch: 15 Learning rate: 1.000
Epoch: 15 Train Perplexity: 1.143
. --- total: 9675, correct: 7739, accuracy: 0.800, 
, --- total: 10403, correct: 6007, accuracy: 0.577, 
? --- total: 383, correct: 32, accuracy: 0.084, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 14, accuracy: 0.054, 
n --- total: 214200, correct: 211194, accuracy: 0.986, 
Epoch: 15 Valid Perplexity: 1.139
Epoch: 16 Learning rate: 1.000
Epoch: 16 Train Perplexity: 1.136
. --- total: 9675, correct: 7799, accuracy: 0.806, 
, --- total: 10403, correct: 6096, accuracy: 0.586, 
? --- total: 383, correct: 87, accuracy: 0.227, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 14, accuracy: 0.054, 
n --- total: 214200, correct: 211133, accuracy: 0.986, 
Epoch: 16 Valid Perplexity: 1.137
Epoch: 17 Learning rate: 1.000
Epoch: 17 Train Perplexity: 1.133
. --- total: 9675, correct: 7625, accuracy: 0.788, 
, --- total: 10403, correct: 5518, accuracy: 0.530, 
? --- total: 383, correct: 148, accuracy: 0.386, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 15, accuracy: 0.058, 
n --- total: 214200, correct: 212207, accuracy: 0.991, 
Epoch: 17 Valid Perplexity: 1.132
Epoch: 18 Learning rate: 1.000
Epoch: 18 Train Perplexity: 1.127
. --- total: 9675, correct: 7517, accuracy: 0.777, 
, --- total: 10403, correct: 4534, accuracy: 0.436, 
? --- total: 383, correct: 140, accuracy: 0.366, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 14, accuracy: 0.054, 
n --- total: 214200, correct: 213088, accuracy: 0.995, 
Epoch: 18 Valid Perplexity: 1.147
Epoch: 19 Learning rate: 1.000
Epoch: 19 Train Perplexity: 1.124
. --- total: 9675, correct: 7590, accuracy: 0.784, 
, --- total: 10403, correct: 5357, accuracy: 0.515, 
? --- total: 383, correct: 214, accuracy: 0.559, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 16, accuracy: 0.062, 
n --- total: 214200, correct: 212445, accuracy: 0.992, 
Epoch: 19 Valid Perplexity: 1.132
Epoch: 20 Learning rate: 1.000
Epoch: 20 Train Perplexity: 1.120
. --- total: 9675, correct: 7518, accuracy: 0.777, 
, --- total: 10420, correct: 4509, accuracy: 0.433, 
? --- total: 381, correct: 206, accuracy: 0.541, 
! --- total: 81, correct: 0, accuracy: 0.000, 
: --- total: 258, correct: 15, accuracy: 0.058, 
n --- total: 214185, correct: 213092, accuracy: 0.995, 
Epoch: 20 Valid Perplexity: 1.147
Epoch: 21 Learning rate: 0.700
Epoch: 21 Train Perplexity: 1.110
. --- total: 9675, correct: 7774, accuracy: 0.804, 
, --- total: 10403, correct: 6213, accuracy: 0.597, 
? --- total: 383, correct: 226, accuracy: 0.590, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 16, accuracy: 0.062, 
n --- total: 214200, correct: 211307, accuracy: 0.986, 
Epoch: 21 Valid Perplexity: 1.130
Epoch: 22 Learning rate: 0.490
Epoch: 22 Train Perplexity: 1.107
. --- total: 9675, correct: 7658, accuracy: 0.792, 
, --- total: 10403, correct: 6091, accuracy: 0.586, 
? --- total: 383, correct: 232, accuracy: 0.606, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 17, accuracy: 0.066, 
n --- total: 214200, correct: 211707, accuracy: 0.988, 
Epoch: 22 Valid Perplexity: 1.130
Epoch: 23 Learning rate: 0.343
Epoch: 23 Train Perplexity: 1.104
. --- total: 9675, correct: 7695, accuracy: 0.795, 
, --- total: 10403, correct: 6008, accuracy: 0.578, 
? --- total: 383, correct: 251, accuracy: 0.655, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 19, accuracy: 0.073, 
n --- total: 214200, correct: 211812, accuracy: 0.989, 
Epoch: 23 Valid Perplexity: 1.129
Epoch: 24 Learning rate: 0.240
Epoch: 24 Train Perplexity: 1.103
. --- total: 9675, correct: 7689, accuracy: 0.795, 
, --- total: 10403, correct: 5973, accuracy: 0.574, 
? --- total: 383, correct: 249, accuracy: 0.650, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 18, accuracy: 0.069, 
n --- total: 214200, correct: 211897, accuracy: 0.989, 
Epoch: 24 Valid Perplexity: 1.129
Epoch: 25 Learning rate: 0.168
Epoch: 25 Train Perplexity: 1.102
. --- total: 9675, correct: 7729, accuracy: 0.799, 
, --- total: 10403, correct: 5919, accuracy: 0.569, 
? --- total: 383, correct: 252, accuracy: 0.658, 
! --- total: 80, correct: 0, accuracy: 0.000, 
: --- total: 259, correct: 18, accuracy: 0.069, 
n --- total: 214200, correct: 211928, accuracy: 0.989, 
Epoch: 25 Valid Perplexity: 1.129

In epoch 1, the model predicted almost everything to be 'n'. This makes sense: the vast majority of targets are "no punctuation", so always betting on 'n' already gives good overall accuracy, even though it is useless.

Starting from epoch 2, it learned to predict some '.'. After epoch 10, it could predict about 50% of ','. Only after epoch 15 did it start predicting some '?'. Unfortunately, it never predicted '!' well, probably because the difference between '.' and '!' is very subtle. It also had trouble with ':', maybe because of the lack of training instances.
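
The class imbalance behind this is easy to check by counting punctuation tokens in the prepared training text (a rough check, assuming train.txt from the data-prep step is still present):


In [ ]:
import collections

# Count how often each punctuation mark appears in the prepared training text.
with open('/content/datalab/punctuation/datapreped/train.txt') as f:
  tokens = f.read().split()
print(collections.Counter(t for t in tokens if t in PUNCTUATIONS))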

Start a TensorBoard instance to see the training/validation loss curves, as well as other stats.


In [23]:
# Start a tensorboard to see the curves in Datalab. 

from google.datalab.ml import TensorBoard
tb = TensorBoard.start(model_dir)


TensorBoard was started successfully with pid 32416. Click here to access it.

TensorBoard is useful, but its curves are not saved with the notebook. We can use Datalab's Summary library to list and plot the events instead.


In [24]:
from google.datalab.ml import Summary
summary = Summary(model_dir)
summary.list_events()


Out[24]:
{u'Model/global_step/sec': {'/content/datalab/punctuation/model/punctuation'},
 u'Train/Learning_Rate': {'/content/datalab/punctuation/model/punctuation'},
 u'Train/TrainInput/input_producer': {'/content/datalab/punctuation/model/punctuation'},
 u'Train/TrainInput_1/input_producer': {'/content/datalab/punctuation/model/punctuation'},
 u'Train/Training_Loss': {'/content/datalab/punctuation/model/punctuation'},
 u'Valid/ValidInput/input_producer': {'/content/datalab/punctuation/model/punctuation'},
 u'Valid/ValidInput_1/input_producer': {'/content/datalab/punctuation/model/punctuation'},
 u'Valid/Validation_Loss': {'/content/datalab/punctuation/model/punctuation'}}

In [25]:
summary.plot(event_names=['Train/Training_Loss', 'Valid/Validation_Loss'])


From the curves above, we get the best validation results around step 4000, with a little overfitting after that in some runs.

Evaluation

At this point training is done, and evaluation starts from the saved checkpoint. We will reuse the PuctuationModel defined earlier, since the evaluation and training models are mostly the same.


In [28]:
"""Run the model with some test data."""

import os


def run_eval(model_file_prefix, test_data_path):
  """Run evaluation on test data."""

  word_to_id_file = os.path.join(os.path.dirname(model_file_prefix), 'word_to_id.json')
  with open(word_to_id_file, 'r') as f:
    word_to_id = json.load(f)
  test_data_x, test_data_y = file_to_word_and_punc_ids(test_data_path, word_to_id)

  eval_config = TrainingConfig()
  eval_config.batch_size = 1
  eval_config.num_steps = 1

  with tf.Graph().as_default():
    with tf.name_scope("Test"):
      test_input = TrainingInput(config=eval_config, data_x=test_data_x, data_y=test_data_y, name="TestInput")
      with tf.variable_scope("Model", reuse=None):
        mtest = PuctuationModel(is_training=False, config=eval_config, input_=test_input)

    logdir=os.path.join(os.path.dirname(model_file_prefix), 'eval')        
    sv = tf.train.Supervisor(logdir=logdir)
    with sv.managed_session() as session:
      sv.saver.restore(session, model_file_prefix)
      test_perplexity, cm_data = run_epoch(session, mtest, 1, word_to_id, is_eval=True)
  return cm_data

View the per-class accuracy and the confusion matrix.


In [29]:
from google.datalab.ml import ConfusionMatrix
from pprint import pprint

cm_data = run_eval(model_file_prefix, '/content/datalab/punctuation/datapreped/test.txt')
pprint(cm_data.tolist())
cm = ConfusionMatrix(cm_data, TARGETS)
cm.plot()


. --- total: 949, correct: 757, accuracy: 0.798, 
, --- total: 998, correct: 559, accuracy: 0.560, 
? --- total: 35, correct: 23, accuracy: 0.657, 
! --- total: 11, correct: 0, accuracy: 0.000, 
: --- total: 22, correct: 0, accuracy: 0.000, 
n --- total: 21059, correct: 20828, accuracy: 0.989, 
[[757, 26, 5, 0, 0, 161],
 [42, 559, 2, 0, 0, 395],
 [3, 0, 23, 0, 0, 9],
 [8, 1, 0, 0, 0, 2],
 [7, 5, 0, 0, 0, 10],
 [78, 147, 5, 0, 1, 20828]]

Plot the confusion matrix again after removing the "no punctuation" class.


In [46]:
cm_data_punctuations = cm_data.tolist()
for i, r in enumerate(cm_data_punctuations):
  cm_data_punctuations[i] = r[:-1]
cm_data_punctuations = cm_data_punctuations[:-1]
ConfusionMatrix(cm_data_punctuations, TARGETS[:-1]).plot()


Many of the "," are mistakenly predicted as "no punctuation", probably because in many cases the sentence is syntactically correct either with or without the comma. There is some confusion between "," and ".", meaning the model "knows" there is a break in the sentence but picks the wrong mark. About 65% of question marks are predicted correctly; for that we can give credit to the LSTM, which can "remember" the beginning of a sentence (which, what, where, etc.) even when the sentence is long.
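
To quantify these confusions, here is a quick sketch that turns the confusion matrix above into per-class precision and recall (using the cm_data array computed in the evaluation cell):


In [ ]:
import numpy as np

# Per-class precision and recall from the confusion matrix.
# Rows of cm_data are true labels, columns are predictions; guard against
# classes that are never predicted (e.g. '!') to avoid dividing by zero.
cm = np.array(cm_data, dtype=np.float64)
recall = cm.diagonal() / np.maximum(cm.sum(axis=1), 1)
precision = cm.diagonal() / np.maximum(cm.sum(axis=0), 1)
for name, p, r in zip(TARGETS, precision, recall):
  print('%s  precision: %.3f  recall: %.3f' % (name, p, r))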

Prediction

Fun time. Let's try generating some punctuation on the test data. We'll need to define a "Prediction Model". It is a simplified version of the training model, with num_steps and batch_size both set to 1, and no loss or training ops. But it is "compatible" with the training model in the sense that they share the same variables, so it can load a checkpoint produced during training.


In [47]:
import tensorflow as tf


class PredictModel(object):
  """The Prediction model."""

  def __init__(self, config):
    self._input = tf.placeholder(shape=[1, 1], dtype=tf.int64)
    size = config.hidden_size

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0, state_is_tuple=True)

    cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(config.num_layers)], state_is_tuple=True)
    self._initial_state = cell.zero_state(1, tf.float32)
    embedding = tf.get_variable("embedding", [VOCAB_SIZE, size], dtype=tf.float32)
    inputs = tf.nn.embedding_lookup(embedding, self._input)
    inputs = tf.unstack(inputs, num=1, axis=1)
    outputs, self._final_state = tf.contrib.rnn.static_rnn(cell, inputs, initial_state=self._initial_state)
    output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, size])
    softmax_w = tf.get_variable("softmax_w", [size, len(TARGETS)], dtype=tf.float32)
    softmax_b = tf.get_variable("softmax_b", [len(TARGETS)], dtype=tf.float32)
    logits = tf.matmul(output, softmax_w) + softmax_b
    self._prediction = tf.argmax(logits, 1)

  @property
  def input(self):
    return self._input
  
  @property
  def initial_state(self):
    return self._initial_state

  @property
  def final_state(self):
    return self._final_state
  
  @property
  def prediction(self):
    return self._prediction

In [48]:
"""The Predictor that runs the prediction model."""

import json
import os
import random


class Predictor(object):
    
  def __init__(self, model_file_prefix):
    word_to_id_file = os.path.join(os.path.dirname(model_file_prefix), 'word_to_id.json')
    with open(word_to_id_file, 'r') as f:
      self._word_to_id = json.load(f)

    config = TrainingConfig()    
    with tf.Graph().as_default():
      with tf.variable_scope("Model", reuse=None):
        self._model = PredictModel(config=config)

      saver = tf.train.Saver()
      self._session = tf.Session()
      saver.restore(self._session, model_file_prefix)

  def _get_predicted_until_punc(self, min_steps, data_x):

    state = self._session.run(self._model.initial_state)
    fetches = {
        "final_state": self._model.final_state,
        "prediction": self._model.prediction,
    }
    predicted_puncs = []
    step = 0
    for x in data_x:
      feed_dict = {}
      for i, (c, h) in enumerate(self._model.initial_state):
        feed_dict[c] = state[i].c
        feed_dict[h] = state[i].h
      feed_dict[self._model.input] = [[x]]

      vals = self._session.run(fetches, feed_dict)
      state = vals["final_state"]
      prediction = vals["prediction"]
      predicted = TARGETS[prediction[0]]
      predicted_puncs.append(predicted)
      step += 1
      if predicted != 'n' and step > min_steps:
        break
    return predicted_puncs
  
  def _apply_puncts_to_original(self, original, inserted):
    current_index = 0
    punc_positions = {}
    for w in inserted.split():
      if w in PUNCTUATIONS:
        punc_positions[current_index] = w
      else:
        current_index += 1
    words = []
    for i, w in enumerate(original.split() + ['']):
      if i in punc_positions:
        words.append(punc_positions[i])
      words.append(w)

    return ' '.join(words)
          
  def predict(self, content):
    """Insert punctuations with given string."""

    content = content.strip().lower()
    for p in PUNCTUATIONS:
      content = content.replace(' ' + p, '')
    prediction_source = content
    prediction_result = ''
    
    # Feed the paragraph through the model repeatedly. Each pass runs until the model
    # predicts a punctuation mark; that mark is spliced into the text so it becomes
    # part of the context for the rest of the paragraph.
    content = '<eos> ' + content + ' <eos>'
    min_step = 0
    while True:
      data_x = content_to_word_ids(content, self._word_to_id)
      puncts = self._get_predicted_until_punc(min_step, data_x)
      if len(data_x) == len(puncts):
        # The whole paragraph has been consumed: replace the trailing <eos> with the
        # final prediction and map the inserted punctuation back onto the original words.
        content = content.replace('. <eos> ', '').replace(' <eos>', ' ' + puncts[-1]) + '\n'
        prediction_result = self._apply_puncts_to_original(prediction_source, content)
        break
      else:
        # Rebuild the text from word ids (unknown words become '<unk>'), undo the
        # punctuation/word swap, insert the newly predicted mark just before the word
        # that triggered it, and continue scanning from this position.
        words1 = [self._word_to_id.keys()[self._word_to_id.values().index(data_x[index])] for index in range(len(puncts) - 1)]
        indices = [i for i, w in enumerate(words1) if w in PUNCTUATIONS]
        for i in indices:
          words1[i], words1[i-1] = words1[i-1], words1[i]
        words2 = [self._word_to_id.keys()[self._word_to_id.values().index(data_x[index])] for index in range(len(puncts) - 1, len(data_x))]
        all_words = words1 + [puncts[-1]] + words2
        content = ' '.join(all_words)
        min_step = len(puncts)
    
    return prediction_source, prediction_result

  def predict_from_test_file(self, filename, num_random_lines):
    """given a file from test file, pick some random lines and do prediction."""    

    num_lines = sum(1 for line in open(filename))
    with open(filename) as f:
      lines = random.sample(f.readlines(), num_random_lines)
    for line in lines:
      line = line.strip().lower()
      source, predicted = self.predict(line)
      yield line, source, predicted

  def close(self):
    self._session.close()

Let's play with three paragraphs. The first and second are single sentences; the third contains multiple sentences.


In [49]:
predictor = Predictor(model_file_prefix)
sources = [
  'i think it is a report which will for the most part be supported by my group',
  'so what is the european union doing about it',
  'we must work more rapidly towards achieving the targets stipulated ' + 
    'in the white paper for renewable energy sources as this would bring ' + 
    'about a massive reduction in greenhouse gases but in common with others ' + 
    ' we too are having to endure the greenhouse effect furthermore we should ' + 
    'utilise an extraordinary budget line since this is an extraordinarily catastrophic situation',
]
for s in sources:
  source, predicted = predictor.predict(s)
  print('\n---SOURCE----\n' + source)
  print('---PREDICTED----\n' + predicted)

predictor.close()


---SOURCE----
i think it is a report which will for the most part be supported by my group
---PREDICTED----
i think it is a report which will , for the most part , be supported by my group . 

---SOURCE----
so what is the european union doing about it
---PREDICTED----
so what is the european union doing about it ? 

---SOURCE----
we must work more rapidly towards achieving the targets stipulated in the white paper for renewable energy sources as this would bring about a massive reduction in greenhouse gases but in common with others  we too are having to endure the greenhouse effect furthermore we should utilise an extraordinary budget line since this is an extraordinarily catastrophic situation
---PREDICTED----
we must work more rapidly towards achieving the targets stipulated in the white paper for renewable energy sources . as this would bring about a massive reduction in greenhouse gases , but in common with others , we too are having to endure the greenhouse effect . furthermore , we should utilise an extraordinary budget line since this is an extraordinarily catastrophic situation . 

The last prediction is actually somewhat incorrect. It should be:

we must work more rapidly towards achieving the targets stipulated in the white paper for renewable energy sources , as this would bring about a massive reduction in greenhouse gases . but in common with others , we too are having to endure the greenhouse effect . furthermore , we should utilise an extraordinary budget line , since this is an extraordinarily catastrophic situation .

The first period it predicted should have been a comma. I think we could improve this by showing the model more than one word after each punctuation mark, or by running it in both directions and mixing the two scores.
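
A rough sketch of the score-mixing idea: suppose we had per-word probability distributions over TARGETS from a left-to-right model and from a second model run over the reversed text (neither exists in this notebook); averaging the two and taking the argmax would combine both contexts:


In [ ]:
import numpy as np

# Hypothetical probability arrays over TARGETS ('.', ',', '?', '!', ':', 'n')
# for two word positions; p_forward comes from a left-to-right pass and
# p_backward from a right-to-left pass. Both are made up for illustration.
p_forward = np.array([[0.05, 0.60, 0.0, 0.0, 0.0, 0.35],
                      [0.55, 0.05, 0.0, 0.0, 0.0, 0.40]])
p_backward = np.array([[0.05, 0.15, 0.0, 0.0, 0.0, 0.80],
                       [0.70, 0.05, 0.0, 0.0, 0.0, 0.25]])
p_mixed = 0.5 * p_forward + 0.5 * p_backward
print([TARGETS[i] for i in p_mixed.argmax(axis=1)])   # ['n', '.']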

Below we try some data outside our test set (the test and training data come from the same source). The first two are common conversational questions, and the third is from recent European Parliament news.


In [50]:
predictor = Predictor(model_file_prefix)
sources = [
  'how are you',
  'where do you see yourself in five years',
  'last december the european commission proposed updating the existing customs union with ' + 
    'turkey and extending bilateral trade relations once negotiations have been completed ' + 
    'the agreement would still have to be approved by the Parliament before it could enter into force',
]
for s in sources:
  source, predicted = predictor.predict(s)
  print('\n---SOURCE----\n' + source)
  print('---PREDICTED----\n' + predicted)

predictor.close()


---SOURCE----
how are you
---PREDICTED----
how are you ? 

---SOURCE----
where do you see yourself in five years
---PREDICTED----
where do you see yourself in five years ? 

---SOURCE----
last december the european commission proposed updating the existing customs union with turkey and extending bilateral trade relations once negotiations have been completed the agreement would still have to be approved by the parliament before it could enter into force
---PREDICTED----
last december , the european commission proposed updating the existing customs union with turkey and extending bilateral trade relations once negotiations have been completed . the agreement would still have to be approved by the parliament before it could enter into force . 

As a convenience, the predictor can also pick random lines from a test file.


In [51]:
predictor = Predictor(model_file_prefix)
for t, s, p in predictor.predict_from_test_file('/content/datalab/punctuation/datapreped/test.txt', 3):
  print('\n---SOURCE----\n' + s)
  print('---PREDICTED----\n' + p)
  print('---TRUTH----\n' + t)
predictor.close()


---SOURCE----
i am also particularly glad that commissioner nielson is present here today this is why it is essential that international aid is not led by the media but only by the people in need the government seems keen to remove the opposition from local authorities where they have been entitled to be since the previous local elections
---PREDICTED----
i am also particularly glad that commissioner nielson is present here . today , this is why it is essential that international aid is not led by the media , but only by the people in need the government seems keen to remove the opposition from local authorities where they have been entitled to be since the previous local elections . 
---TRUTH----
i am also particularly glad that commissioner nielson is present here today . this is why it is essential that international aid is not led by the media but only by the people in need . the government seems keen to remove the opposition from local authorities , where they have been entitled to be since the previous local elections .

---SOURCE----
since france laid a huge amount of emphasis on legal exemption at the time it was damaged by concessions in agricultural policy finally and in conclusion [humanname] with the expiry of the ecsc treaty the regulations will have to be reviewed since i think that the aid system will have to continue beyond [number] and in that case i am in favour of a council regulation which will ensure security in this area what is the commissioners opinion of this
---PREDICTED----
since france laid a huge amount of emphasis on legal exemption at the time , it was damaged by concessions in agricultural policy . finally , and in conclusion , [humanname] , with the expiry of the ecsc treaty , the regulations will have to be reviewed since i think that the aid system will have to continue beyond [number] and in that case . i am in favour of a council regulation which will ensure security in this area . what is the commissioners opinion of this ? 
---TRUTH----
since france laid a huge amount of emphasis on legal exemption at the time , it was damaged by concessions in agricultural policy . finally , and in conclusion , [humanname] , with the expiry of the ecsc treaty , the regulations will have to be reviewed since i think that the aid system will have to continue beyond [number] , and in that case i am in favour of a council regulation which will ensure security in this area . what is the commissioners opinion of this ?

---SOURCE----
provide the commission with the necessary resources and recognition essential for it to carry out its work and for it to verify the real progress made at community level and at the level of the member states and we will be - and the council will be - able to deal with the decisions made at tampere and move towards a union which is freer more just and more secure the increase in the number of positive results demonstrates the efficiency of the system and there is a continuing increase in the number of reports i would now like to state before you here that we firmly intend to consult the european parliament not only in those cases stipulated by the treaty which we are obliged to do but also to inform and consult parliament whenever we believe that it is appropriate to widen this type of consultation and when in direct contact with the commission we believe that parliament should be able to present its opinion
---PREDICTED----
provide the commission with the necessary resources and recognition essential for it to carry out its work and for it to verify the real progress made at community level and at the level of the member states and we will be - and the council will be - able to deal with the decisions made at tampere and move towards a union which is freer more just and more secure . the increase in the number of positive results demonstrates the efficiency of the system and there is a continuing increase in the number of reports . i would now like to state before you here that we firmly intend to consult the european parliament , not only in those cases stipulated by the treaty , which we are obliged to do but also to inform and consult parliament whenever we believe that it is appropriate to widen this type of consultation and when , in direct contact with the commission , we believe that parliament should be able to present its opinion . 
---TRUTH----
provide the commission with the necessary resources and recognition essential for it to carry out its work and for it to verify the real progress made at community level and at the level of the member states , and we will be - and the council will be - able to deal with the decisions made at tampere and move towards a union which is freer , more just and more secure . the increase in the number of positive results demonstrates the efficiency of the system and there is a continuing increase in the number of reports . i would now like to state before you here that we firmly intend to consult the european parliament , not only in those cases stipulated by the treaty , which we are obliged to do , but also to inform and consult parliament whenever we believe that it is appropriate to widen this type of consultation , and when , in direct contact with the commission , we believe that parliament should be able to present its opinion .

Clean up


In [52]:
TensorBoard.stop(tb)

In [ ]: